# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Generic, overload, Tuple, TypeVar

from pyre_extensions import Add, TypeVarTuple
from pyre_extensions.type_variable_operators import Concatenate
from typing_extensions import Literal

# Type variables for individual dimensions. Each is bounded by ``int`` and is
# used below to name one position of a tensor's shape.
A = TypeVar("A", bound=int)
B = TypeVar("B", bound=int)
C = TypeVar("C", bound=int)
D = TypeVar("D", bound=int)

# Variadic type variables: ``Shape`` is the generic parameter of ``Tensor``;
# ``Ts`` appears in signatures as the remaining dimensions after the ones
# named explicitly.
Shape = TypeVarTuple("Shape")
Ts = TypeVarTuple("Ts")

class Tensor(Generic[Shape]):
    """Declaration of a tensor type parameterized by its shape.

    ``Shape`` is a variadic type variable. The overloads below use
    ``Concatenate`` to name individual leading dimensions (``A``, ``B``, ...)
    while ``Ts`` stands for whatever dimensions remain.

    Every reference to ``Tensor`` inside the class body is written as a
    string annotation, so no annotation depends on the name ``Tensor``
    already being bound while the body is evaluated.
    """

    def __init__(self, *shape: Shape): ...

    # transpose(1, 2): the second and third dimensions (B, C) trade places.
    @overload
    def transpose(
        self: "Tensor[Concatenate[A,B,C,Ts]]", d1: Literal[1], d2: Literal[2]
    ) -> "Tensor[Concatenate[A,C,B,Ts]]": ...

    # transpose(2, 3): the third and fourth dimensions (C, D) trade places.
    @overload
    def transpose(
        self: "Tensor[Concatenate[A,B,C,D,Ts]]", d1: Literal[2], d2: Literal[3]
    ) -> "Tensor[Concatenate[A,B,D,C,Ts]]": ...

    # The result carries the same shape as the receiver.
    def contiguous(self) -> "Tensor[Shape]": ...

    # NOTE(review): the commented-out signature below is presumably the
    # intended general form of ``view``; it is kept for reference only.
    # def view(self, *new_shape : Concatenate[Ts1,Literal[-1],Ts2]) -> "Tensor[Concatenate[Ts1,int,Ts2]]" : ...

    # view(a, -1, c): the dimension given as -1 is typed as ``Any``.
    @overload
    def view(self, a: A, b: Literal[-1], c: C) -> "Tensor[A,Any,C]": ...

    # view(a, -1, c, d): same as above with one more trailing dimension.
    @overload
    def view(self, a: A, b: Literal[-1], c: C, d: D) -> "Tensor[A,Any,C,D]": ...

    # view(new_shape): the result shape is taken from the tuple's element types.
    @overload
    def view(self, new_shape: Tuple[Ts]) -> "Tensor[Ts]": ...

    # Both methods return a tensor typed with the argument's shape.
    def expand_as(self, as_tensor: "Tensor[Ts]") -> "Tensor[Ts]": ...
    def type_as(self, as_tensor: "Tensor[Ts]") -> "Tensor[Ts]": ...

    # size(): the whole shape as a tuple.
    @overload
    def size(self) -> Tuple[Shape]: ...

    # size(1): the second dimension (B) of a tensor with at least two dimensions.
    @overload
    def size(self: "Tensor[Concatenate[A,B,Ts]]", p: Literal[1]) -> B: ...

    def dim(self) -> int: ...

    # The mask must have the receiver's shape; the result keeps that shape.
    def masked_fill(self, mask: "Tensor[Shape]", value) -> "Tensor[Shape]": ...

    # The operations below are all typed as shape-preserving.
    def __truediv__(self, a: Any) -> "Tensor[Shape]": ...
    def __eq__(self, a: int) -> "Tensor[Shape]": ...
    def float(self) -> "Tensor[Shape]": ...

@overload
def cat(
    x: Tensor[Concatenate[A, B, C, Ts]], y: Tensor[Concatenate[A, B, D, Ts]], dim: Literal[2]
) -> Tensor[Concatenate[A, B, Add[C, D], Ts]]:
    """Typed signature of ``cat`` for ``dim=2``.

    ``x`` and ``y`` share their first two dimensions (``A``, ``B``) and all
    trailing dimensions (``Ts``). Their third dimensions are ``C`` and ``D``,
    and the result's third dimension is typed as ``Add[C, D]``.
    """
    ...
@overload
def matmul(
    x: Tensor[...],
    y: Tensor[...],
) -> Tensor[...]:
    """Shape-agnostic signature of ``matmul``.

    Both operands and the result are annotated ``Tensor[...]``, so this
    overload expresses no relationship between their shapes.
    """
    ...
